load("//bazel:api.bzl", "mojo_test")

# Per-test overrides for the "test.resources:gpu-memory" exec property, keyed
# by source file name. Sources not listed here fall back to the "5" default
# applied where the mojo_test targets are generated.
# NOTE(review): the unit of these values (presumably GiB) is defined by the
# remote-execution config, not by this file — confirm before editing.
_GPU_MEMORY = dict([
    ("test_batch_kv_cache_flash_attention_causal_mask_ragged.mojo", "8"),
])

# Constraint list that makes a target incompatible when building for Apple
# GPUs (`//:apple_gpu`) and adds nothing on every other configuration.
# Defined once so the exclusion cannot drift between the tests that use it.
_APPLE_GPU_INCOMPATIBLE = select({
    "//:apple_gpu": ["@platforms//:incompatible"],
    "//conditions:default": [],
})

# Per-test additions to `target_compatible_with`, keyed by source file name.
# Each entry carries the ticket that tracks why the test is disabled there.
_EXTRA_CONSTRAINTS = {
    "test_batch_kv_cache_flash_attention_causal_mask_ragged.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: PRDT-506
    "test_batch_kv_cache_flash_attention.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: PRDT-506
    "test_kv_cache_matmul.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: PRDT-506
    "test_kv_cache_ragged_matmul.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: PRDT-506
    "test_mha_mixed_ce_tg.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: MOCO-2366
    "test_batch_kv_cache_flash_attention_causal_mask.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: PRDT-506
    "test_batch_kv_cache_flash_attention_causal_mask_ragged_paged.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: PRDT-506
    "test_kv_cache_ragged_matmul_scale.mojo": _APPLE_GPU_INCOMPATIBLE,  # FIXME: MOCO-2596
}

# Generates one `mojo_test` per `.mojo` file matched (recursively) by the glob
# below. Per-test GPU memory requests and platform constraints are looked up
# in `_GPU_MEMORY` and `_EXTRA_CONSTRAINTS`. Sources absent from those tables
# request "5" gpu-memory and are constrained only by `//:has_gpu`.
[
    mojo_test(
        name = test_src + ".test",
        size = "large",
        srcs = [test_src],
        enable_assertions = select({
            # Assertions are always disabled on AMD GPUs.
            "//:amd_gpu": False,
            # TODO: MSTDL-1147 understand why this test fails with asserts turned on.
            "//conditions:default": test_src not in ["test_kv_cache_ragged_matmul.mojo"],
        }),
        exec_properties = {
            "test.resources:gpu-memory": _GPU_MEMORY.get(test_src, "5"),
        },
        tags = ["gpu"],
        target_compatible_with = ["//:has_gpu"] + _EXTRA_CONSTRAINTS.get(test_src, []),
        deps = [
            "//max:internal_utils",
            "//max:kv_cache",
            "//max:linalg",
            "//max:nn",
            "//max:quantization",
            "@mojo//:stdlib",
        ],
    )
    for test_src in glob(["**/*.mojo"])
]
